{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Put the parent directory on sys.path (only once) so modules located there can be imported.\n",
    "module_path = os.path.abspath('..')\n",
    "if module_path not in sys.path:\n",
    "    sys.path.append(module_path)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import getpass\n",
    "\n",
    "# Use the key from the environment when it is already set; otherwise prompt for it\n",
    "# without echoing, so the secret is never typed into (or saved with) the notebook.\n",
    "if not os.environ.get('OPENAI_API_KEY'):\n",
    "    os.environ['OPENAI_API_KEY'] = getpass.getpass('Enter your OpenAI API key: ')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from PIL import Image\n",
    "from IPython.core.display import HTML\n",
    "from functools import partial\n",
    "\n",
    "from engine.utils import ProgramGenerator, ProgramInterpreter\n",
    "from prompts.nlvr import create_prompt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Interpreter that executes the generated program in a later cell, created with dataset='nlvr'.\n",
    "interpreter = ProgramInterpreter(dataset='nlvr')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Pre-bind method='all' on create_prompt and hand the resulting callable to the program generator.\n",
    "prompter = partial(create_prompt, method='all')\n",
    "generator = ProgramGenerator(prompter=prompter)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def load_thumbnail(path, max_size=(640, 640)):\n",
    "    \"\"\"Open the image at `path`, shrink it to fit within `max_size`, and return an RGB copy.\n",
    "\n",
    "    Both input images go through the same steps, so the pipeline lives in one place\n",
    "    and the size limit is a parameter instead of a repeated literal.\n",
    "    \"\"\"\n",
    "    # The `with` block makes the lifetime of the underlying file handle explicit.\n",
    "    with Image.open(path) as img:\n",
    "        # thumbnail() resizes in place, keeping the aspect ratio.\n",
    "        img.thumbnail(max_size, Image.Resampling.LANCZOS)\n",
    "        # convert() returns a new image, which is what gets handed back to the caller.\n",
    "        return img.convert('RGB')\n",
    "\n",
    "\n",
    "left_image = load_thumbnail('../assets/camel1.png')\n",
    "right_image = load_thumbnail('../assets/camel2.png')\n",
    "# Initial program state: the two images under the keys LEFT and RIGHT.\n",
    "init_state = dict(\n",
    "    LEFT=left_image,\n",
    "    RIGHT=right_image,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Natural-language statement handed to the program generator.\n",
    "# To try a different one, uncomment one of the alternatives below the active line.\n",
    "statement = \"There is a woman in black in exactly one of the images and two camels in total.\"\n",
    "# statement = \"There are more camels than people.\"\n",
    "# statement = \"The camel is sitting in both images.\"\n",
    "# statement = \"There are no trees in the images.\"\n",
    "\n",
    "# generate() returns a pair; only the first element (the program) is kept and printed.\n",
    "prog, _ = generator.generate({'statement': statement})\n",
    "print(prog)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Execute the generated program on the initial state; the call returns the result,\n",
    "# the program state, and an HTML string that is rendered in a later cell.\n",
    "result, prog_state, html_str = interpreter.execute(\n",
    "    prog,\n",
    "    init_state,\n",
    "    inspect=True,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Show the value returned by interpreter.execute as this cell's output.\n",
    "result"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Render the HTML string returned by interpreter.execute.\n",
    "HTML(html_str)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "visprog",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.10"
  },
  "orig_nbformat": 4
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
